{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_kg_hide-input": true,
    "_kg_hide-output": true,
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/usst/anaconda3/envs/dxzpy/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "_kg_hide-input": true,
    "_kg_hide-output": true,
    "_uuid": "254dffaff56ab4c8536c59c63863947816fd0e67"
   },
   "outputs": [],
   "source": [
    "label_names = {\n",
    "    0:\"apple healthy（苹果健康）\",\n",
    "    1:\"apple_Scab general（苹果黑星病一般）\",\n",
    "    2:\"apple_Scab serious（苹果黑星病严重）\",\n",
    "    3:\"apple Frogeye Spot（苹果灰斑病）\",\n",
    "    4:\"Cedar Apple Rust  general（苹果雪松锈病一般）\",\n",
    "    5:\"Cedar Apple Rust serious（苹果雪松锈病严重）\",\n",
    "    6:\"Cherry healthy（樱桃健康）\",\n",
    "    7:\"Cherry_Powdery Mildew  general（樱桃白粉病一般）\",\n",
    "    8:\"Cherry_Powdery Mildew  serious（樱桃白粉病严重）\",\n",
    "    9:\"Corn healthy（玉米健康）\",\n",
    "    10:\"Cercospora zeaemaydis Tehon and Daniels general（玉米灰斑病一般）\",\n",
    "    11:\"Cercospora zeaemaydis Tehon and Daniels  serious（玉米灰斑病严重）\",\n",
    "    12:\"Puccinia polysora  general（玉米锈病一般）\",\n",
    "    13:\"Puccinia polysora serious（玉米锈病严重）\",\n",
    "    14:\"Corn Curvularia leaf spot fungus general（玉米叶斑病一般）\",\n",
    "    15:\"Corn Curvularia leaf spot fungus  serious（玉米叶斑病严重）\",\n",
    "    16:\"Maize dwarf mosaic virus（玉米花叶病毒病）\",\n",
    "    17:\"Grape heathy（葡萄健康）\",\n",
    "    18:\"Grape Black Rot Fungus general（葡萄黑腐病一般）\",\n",
    "    19:\"Grape Black Rot Fungus serious（葡萄黑腐病严重）\",\n",
    "    20:\"Grape Black Measles Fungus general（葡萄轮斑病一般）\",\n",
    "    21:\"Grape Black Measles Fungus serious（葡萄轮斑病严重）\",\n",
    "    22:\"Grape Leaf Blight Fungus general（葡萄褐斑病一般）\",\n",
    "    23:\"Grape Leaf Blight Fungus  serious（葡萄褐斑病严重）\",\n",
    "    24:\"Citrus healthy（柑桔健康）\",\n",
    "    25:\"Citrus Greening June  general（柑桔黄龙病一般）\",\n",
    "    26:\"Citrus Greening June  serious（柑桔黄龙病严重）\",\n",
    "    27:\"Peach healthy（桃健康）\",\n",
    "    28:\"Peach_Bacterial Spot general（桃疮痂病一般）\",\n",
    "    29:\"Peach_Bacterial Spot  serious（桃疮痂病严重）\",\n",
    "    30:\"Pepper healthy（辣椒健康）\",\n",
    "    31:\"Pepper scab general（辣椒疮痂病一般）\",\n",
    "    32:\"Pepper scab  serious（辣椒疮痂病严重）\",\n",
    "    33:\"Potato healthy（马铃薯健康）\",\n",
    "    34:\"Potato_Early Blight Fungus general（马铃薯早疫病一般）\",\n",
    "    35:\"Potato_Early Blight Fungus serious（马铃薯早疫病严重）\",\n",
    "    36:\"Potato_Late Blight Fungus general（马铃薯晚疫病一般）\",\n",
    "    37:\"Potato_Late Blight Fungus  serious（马铃薯晚疫病严重）\",\n",
    "    38:\"Strawberry healthy（草莓健康）\",\n",
    "    39:\"Strawberry_Scorch general（草莓叶枯病一般）\",\n",
    "    40:\"Strawberry_Scorch serious（草莓叶枯病严重）\",\n",
    "    41:\"tomato healthy（番茄健康）\",\n",
    "    42:\"tomato powdery mildew  general（番茄白粉病一般）\",\n",
    "    43:\"tomato powdery mildew  serious（番茄白粉病严重）\",\n",
    "    44:\"tomato Bacterial Spot Bacteria general（番茄疮痂病一般）\",\n",
    "    45:\"tomato Bacterial Spot Bacteria  serious（番茄疮痂病严重）\",\n",
    "    46:\"tomato_Early Blight Fungus general（番茄早疫病一般）\",\n",
    "    47:\"tomato_Early Blight Fungus  serious（番茄早疫病严重）\",\n",
    "    48:\"tomato_Late Blight Water Mold  general（番茄晚疫病菌一般）\",\n",
    "    49:\"tomato_Late Blight Water Mold serious（番茄晚疫病菌严重）\",\n",
    "    50:\"tomato_Leaf Mold Fungus general（番茄叶霉病一般）\",\n",
    "    51:\"tomato_Leaf Mold Fungus serious（番茄叶霉病严重）\",\n",
    "    52:\"tomato Target Spot Bacteria  general（番茄斑点病一般）\",\n",
    "    53:\"tomato Target Spot Bacteria  serious（番茄斑点病严重）\",\n",
    "    54:\"tomato_Septoria Leaf Spot Fungus  general（番茄斑枯病一般）\",\n",
    "    55:\"tomato_Septoria Leaf Spot Fungus  serious（番茄斑枯病严重）\",\n",
    "    56:\"tomato Spider Mite Damage general（番茄红蜘蛛损伤一般）\",\n",
    "    57:\"tomato Spider Mite Damage serious（番茄红蜘蛛损伤严重）\",\n",
    "    58:\"tomato YLCV Virus general（番茄黄化曲叶病毒病一般）\",\n",
    "    59:\"tomato YLCV Virus  serious（番茄黄化曲叶病毒病严重）\",\n",
    "    60:\"tomato Tomv（番茄花叶病毒病）\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "_uuid": "a899007205c1fb5941dd91c54afbb312b3b86416"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import torchvision.models as models\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "\n",
    "#1. set random.seed\n",
    "import random \n",
    "seed = 34\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "from sedensenet import densenet161"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "_uuid": "5eaf9968b312a8ee5ce50851abf262bb19a0b7cb"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/usst/数据/DXZ/ai-challenger-pdr/sedensenet.py:111: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n",
      "  nn.init.kaiming_normal(m.weight.data)\n"
     ]
    }
   ],
   "source": [
    "use_gpu = True\n",
    "num_classes = 61\n",
    "\n",
    "model = densenet161()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "_uuid": "92c0b95aa1909737a42f04334aadee1701a85f8a"
   },
   "outputs": [],
   "source": [
    "model.classifier = nn.Sequential(\n",
    "    nn.Dropout(0.5),\n",
    "    nn.Linear(2208, num_classes),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "_uuid": "286bcd4b25a1db7269de40aeb621b6f998321536"
   },
   "outputs": [],
   "source": [
    "device_ids = [0,1]\n",
    "\n",
    "if use_gpu:\n",
    "    model = model.cuda(device_ids[0])\n",
    "    model = nn.DataParallel(model, device_ids=device_ids)\n",
    "model.load_state_dict(torch.load('sedensenet.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "_uuid": "04258289ccd938bc12d423b8fb055cae8f96da42"
   },
   "outputs": [],
   "source": [
    "#path = '../input/ai-challenger-pdr/ai_challenger_pdr2018_testa_20181023/AgriculturalDisease_testA/images/'\n",
    "path = './test/testB/testB/images/'\n",
    "\n",
    "trans_valid = transforms.Compose([transforms.Resize(size=224),\n",
    "                            transforms.CenterCrop(size=224),\n",
    "                            transforms.ToTensor(),\n",
    "                            transforms.Normalize(mean=[0.47954108864506007, 0.5295650244021952, 0.39169756009537665],\n",
    "                                 std=[0.21481591229053462, 0.20095268035289796, 0.24845895286079178])])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "_uuid": "b29ddd13d873b776044fa51e2b32ddcdd70defb1"
   },
   "outputs": [],
   "source": [
    "def get_prediction(model, loader, valid=False):\n",
    "    prediction = np.array([])\n",
    "    model.module.eval()\n",
    "    for _, data in enumerate(loader):\n",
    "        if valid:\n",
    "            inputs,_ = data\n",
    "        else:\n",
    "            inputs = data\n",
    "        print('.', end='')\n",
    "        if use_gpu:\n",
    "            inputs = inputs.cuda()\n",
    "        outputs = model(inputs)\n",
    "        pred = torch.argmax(outputs.data, dim=1)\n",
    "        prediction = np.append(prediction, pred.cpu().numpy())\n",
    "    return prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "_uuid": "960137e7fb3a6835774fc56917f08256ccc3c842"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".............................................................................................................................................."
     ]
    }
   ],
   "source": [
    "class TestDataset(Dataset):\n",
    "    def __init__(self, data_dir = './', transform=None):\n",
    "        super().__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.transform = transform\n",
    "        self.image_names = os.listdir(data_dir)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.image_names)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        img_name = self.image_names[index]\n",
    "        img_path = os.path.join(self.data_dir, img_name)\n",
    "        try:\n",
    "            img = Image.open(img_path)\n",
    "        except OSError:\n",
    "            print('read with cv2')\n",
    "            img = Image.fromarray(cv2.imread(img_path))\n",
    "        image = img.convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image\n",
    "\n",
    "dataset_test = TestDataset(data_dir=path, transform=trans_valid)\n",
    "loader_test = DataLoader(dataset = dataset_test, batch_size=32, shuffle=False, num_workers=0)\n",
    "\n",
    "test_prediction = get_prediction(model, loader_test)\n",
    "\n",
    "sub = pd.DataFrame(list(zip(dataset_test.image_names,test_prediction.astype(int))),\n",
    "                   columns=['image_id', 'disease_class'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "_uuid": "5600998d47250314852c333e8bb3708c8d110d00"
   },
   "outputs": [],
   "source": [
    "sub.to_json('testb2.json',orient='records',force_ascii=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
